{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Span 设置"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/media/pc/data/lxw/ai/tvm-book/doc/read\n"
     ]
    }
   ],
   "source": [
    "%cd ../..\n",
    "import set_env\n",
    "from tools.tag_span import _set_span, _create_span, _verify_structural_equal_with_span"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "from tvm import relay, testing\n",
    "from tvm.relay.frontend.common import StrAttrsDict, set_span"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_key_is_present():\n",
    "    \"\"\"Check that `has_attr` is true for a key given to the constructor.\"\"\"\n",
    "    assert StrAttrsDict({\"a\": 1}).has_attr(\"a\")\n",
    "\n",
    "\n",
    "def test_key_is_not_present():\n",
    "    \"\"\"Check that `has_attr` is false for a key never given to the constructor.\"\"\"\n",
    "    assert not StrAttrsDict({\"a\": 1}).has_attr(\"b\")\n",
    "\n",
    "\n",
    "# Run both checks right away so the cell fails loudly if either assertion breaks.\n",
    "for check in (test_key_is_present, test_key_is_not_present):\n",
    "    check()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 测试 Span pass 开关"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _res(should_fill):\n",
    "    \"\"\"Tag a fresh `x` var with the `x_var` span while span filling is enabled or disabled.\"\"\"\n",
    "    # Pick the context manager up front so the tagging call is written only once.\n",
    "    span_mode = testing.enable_span_filling if should_fill else testing.disable_span_filling\n",
    "    with span_mode():\n",
    "        return set_span(relay.var(\"x\", shape=(1, 64, 56, 56)), \"x_var\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Span filling off: the result should match a var built without any span.\n",
    "disable = relay.var(\"x\", shape=(1, 64, 56, 56))\n",
    "_verify_structural_equal_with_span(_res(False), disable)\n",
    "\n",
    "# Span filling on: the result should match a var built with the `x_var` span.\n",
    "enable = relay.var(\"x\", shape=(1, 64, 56, 56), span=_create_span(\"x_var\"))\n",
    "_verify_structural_equal_with_span(_res(True), enable)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "free_var %x: Tensor[(1, 64, 56, 56), float32] /* span=x_var:0:0 */;\n",
      "%x\n"
     ]
    }
   ],
   "source": [
    "print(enable)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`set_span` 应当为所有尚未带有 Span 的表达式打上标记；一旦遇到已经带有 Span 的表达式，就停止继续向下标记，并保留其原有的 Span。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 测试内建元组 Span"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _res():\n",
    "    \"\"\"Pass a built-in tuple to `set_span`; only the first element is pre-tagged (`a`).\"\"\"\n",
    "    tagged = relay.const(np.ones([1, 1, 1]), dtype=\"int64\", span=_create_span(\"a\"))\n",
    "    untagged = relay.const(np.zeros([1, 1, 1]), dtype=\"int64\")\n",
    "    return set_span((tagged, untagged), \"tuple\")\n",
    "\n",
    "\n",
    "def _golden():\n",
    "    \"\"\"Expected elements: the first constant keeps `a`, the second carries `tuple`.\"\"\"\n",
    "    return (\n",
    "        relay.const(np.ones([1, 1, 1]), dtype=\"int64\", span=_create_span(\"a\")),\n",
    "        relay.const(np.zeros([1, 1, 1]), dtype=\"int64\", span=_create_span(\"tuple\")),\n",
    "    )\n",
    "\n",
    "\n",
    "res_tuple, golden_tuple = _res(), _golden()\n",
    "assert len(res_tuple) == len(golden_tuple)\n",
    "# Compare element by element; the length assert above guards against silent truncation by zip.\n",
    "for res_item, golden_item in zip(res_tuple, golden_tuple):\n",
    "    _verify_structural_equal_with_span(res_item, golden_item)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "meta[relay.Constant][0] /* span=a:0:0 */\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(golden_tuple[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 测试内建列表 Span"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _res():\n",
    "    \"\"\"Pass a built-in list of `TupleGetItem` nodes to `set_span`; only constant `a` is pre-tagged.\"\"\"\n",
    "    a = relay.const(np.ones([1, 1, 1]), dtype=\"int64\", span=_create_span(\"a\"))\n",
    "    b = relay.const(np.zeros([1, 1, 1]), dtype=\"int64\")\n",
    "    tup = relay.Tuple([a, b])\n",
    "    return set_span([relay.TupleGetItem(tup, idx) for idx in (0, 1)], \"list\")\n",
    "\n",
    "\n",
    "def _golden():\n",
    "    \"\"\"Expected list: `a` keeps its span; `b`, the tuple and both getters carry `list`.\"\"\"\n",
    "    a = relay.const(np.ones([1, 1, 1]), dtype=\"int64\", span=_create_span(\"a\"))\n",
    "    b = relay.const(np.zeros([1, 1, 1]), dtype=\"int64\", span=_create_span(\"list\"))\n",
    "    tup = relay.Tuple([a, b], span=_create_span(\"list\"))\n",
    "    return [relay.TupleGetItem(tup, idx, span=_create_span(\"list\")) for idx in (0, 1)]\n",
    "\n",
    "\n",
    "res_list, golden_list = _res(), _golden()\n",
    "assert len(res_list) == len(golden_list)\n",
    "# Compare element by element; the length assert above guards against silent truncation by zip.\n",
    "for res_item, golden_item in zip(res_list, golden_list):\n",
    "    _verify_structural_equal_with_span(res_item, golden_item)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 测试 `relay.var` Span"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Expected: the var constructed with the `x_var` span attached directly.\n",
    "x_expected = relay.var(\"x\", shape=(1, 64, 56, 56), span=_create_span(\"x_var\"))\n",
    "# Actual: a span-less var that is tagged afterwards through `set_span`.\n",
    "x = set_span(relay.var(\"x\", shape=(1, 64, 56, 56)), \"x_var\")\n",
    "_verify_structural_equal_with_span(x, x_expected)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 测试 `relay.const` Span"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Expected: the constant constructed with the `const_c` span attached directly.\n",
    "c_expected = relay.const(np.ones([64, 64, 3, 3]), dtype=\"int64\", span=_create_span(\"const_c\"))\n",
    "# Actual: a span-less constant that is tagged afterwards through `set_span`.\n",
    "c = set_span(relay.const(np.ones([64, 64, 3, 3]), dtype=\"int64\"), \"const_c\")\n",
    "_verify_structural_equal_with_span(c, c_expected)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 测试 `relay.Call` Span"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _res():\n",
    "    \"\"\"Tag the input var and the conv2d call via `set_span`; the weight constant starts without a span.\"\"\"\n",
    "    data = set_span(relay.var(\"x\", shape=(1, 64, 56, 56)), \"x_var\")\n",
    "    weight = relay.const(np.ones([64, 64, 3, 3]), dtype=\"int64\")\n",
    "    conv = relay.nn.conv2d(data, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))\n",
    "    return relay.Function([data], set_span(conv, \"conv2d\"))\n",
    "\n",
    "\n",
    "def _golden():\n",
    "    \"\"\"Expected function: the weight constant and the conv2d call both carry `conv2d`.\"\"\"\n",
    "    data = relay.var(\"x\", shape=(1, 64, 56, 56), span=_create_span(\"x_var\"))\n",
    "    weight = relay.const(np.ones([64, 64, 3, 3]), dtype=\"int64\", span=_create_span(\"conv2d\"))\n",
    "    conv = relay.nn.conv2d(data, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))\n",
    "    return relay.Function([data], _set_span(conv, \"conv2d\"))\n",
    "\n",
    "\n",
    "_verify_structural_equal_with_span(_res(), _golden())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fn (%x: Tensor[(1, 64, 56, 56), float32] /* span=x_var:0:0 */) {\n",
      "  nn.conv2d(%x, meta[relay.Constant][0] /* span=conv2d:0:0 */, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]) /* span=conv2d:0:0 */\n",
      "}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(_golden())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 测试 `relay.Tuple` Span"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _res():\n",
    "    \"\"\"Tag constant `a` and then the enclosing `relay.Tuple` (`t`); constant `b` starts without a span.\"\"\"\n",
    "    tagged = set_span(relay.const(np.ones([1, 1, 1]), dtype=\"int64\"), \"a\")\n",
    "    untagged = relay.const(np.ones([1, 1, 1]), dtype=\"int64\")\n",
    "    return relay.Function([], set_span(relay.Tuple([tagged, untagged]), \"t\"))\n",
    "\n",
    "\n",
    "def _golden():\n",
    "    \"\"\"Expected function: `a` keeps its span; `b` and the tuple carry `t`.\"\"\"\n",
    "    first = relay.const(np.ones([1, 1, 1]), dtype=\"int64\", span=_create_span(\"a\"))\n",
    "    second = relay.const(np.ones([1, 1, 1]), dtype=\"int64\", span=_create_span(\"t\"))\n",
    "    return relay.Function([], relay.Tuple([first, second], span=_create_span(\"t\")))\n",
    "\n",
    "\n",
    "_verify_structural_equal_with_span(_res(), _golden())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 测试 `relay.TupleGetItem` Span"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _res():\n",
    "    \"\"\"Tag constant `a` and the `TupleGetItem` (`i`); `b` and the tuple start without a span.\"\"\"\n",
    "    tagged = set_span(relay.const(np.ones([1, 1, 1]), dtype=\"int64\"), \"a\")\n",
    "    untagged = relay.const(np.ones([1, 1, 1]), dtype=\"int64\")\n",
    "    item = set_span(relay.TupleGetItem(relay.Tuple([tagged, untagged]), 0), \"i\")\n",
    "    return relay.Function([], item)\n",
    "\n",
    "\n",
    "def _golden():\n",
    "    \"\"\"Expected function: `a` keeps its span; `b`, the tuple and the getter carry `i`.\"\"\"\n",
    "    first = relay.const(np.ones([1, 1, 1]), dtype=\"int64\", span=_create_span(\"a\"))\n",
    "    second = relay.const(np.ones([1, 1, 1]), dtype=\"int64\", span=_create_span(\"i\"))\n",
    "    tup = relay.Tuple([first, second], span=_create_span(\"i\"))\n",
    "    return relay.Function([], relay.TupleGetItem(tup, 0, span=_create_span(\"i\")))\n",
    "\n",
    "\n",
    "_verify_structural_equal_with_span(_res(), _golden())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 测试 `relay.Let` Span"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _res():\n",
    "    \"\"\"Tag the var, the `Let`, the zeros constant and the outer add; the bound constant and inner add start untagged.\"\"\"\n",
    "    x = set_span(relay.Var(\"x\"), \"x_var\")\n",
    "    let_expr = set_span(relay.Let(x, relay.const(np.ones(10)), relay.add(x, x)), \"let\")\n",
    "    zeros = set_span(relay.const(np.zeros(10)), \"zeros\")\n",
    "    return relay.Function([x], set_span(relay.add(let_expr, zeros), \"add_2\"))\n",
    "\n",
    "\n",
    "def _golden():\n",
    "    \"\"\"Expected function: the bound constant and inner add carry `let`, like the `Let` itself.\"\"\"\n",
    "    x = relay.Var(\"x\", span=_create_span(\"x_var\"))\n",
    "    bound_value = relay.const(np.ones(10), span=_create_span(\"let\"))\n",
    "    inner_add = _set_span(relay.add(x, x), \"let\")\n",
    "    let_expr = relay.Let(x, bound_value, inner_add, span=_create_span(\"let\"))\n",
    "    zeros = relay.const(np.zeros(10), span=_create_span(\"zeros\"))\n",
    "    return relay.Function([x], _set_span(relay.add(let_expr, zeros), \"add_2\"))\n",
    "\n",
    "\n",
    "_verify_structural_equal_with_span(_res(), _golden())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 测试 `relay.If` Span"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _res():\n",
    "    \"\"\"Tag both vars, the true branch and the `If`; the condition and false branch start untagged.\"\"\"\n",
    "    x = set_span(relay.var(\"x\", shape=[], dtype=\"float32\"), \"x_var\")\n",
    "    y = set_span(relay.var(\"y\", shape=[], dtype=\"float32\"), \"y_var\")\n",
    "    cond = relay.equal(x, y)\n",
    "    then_expr = set_span(relay.add(x, y), \"true_branch\")\n",
    "    else_expr = relay.subtract(x, y)\n",
    "    return relay.Function([x, y], set_span(relay.If(cond, then_expr, else_expr), \"if\"))\n",
    "\n",
    "\n",
    "def _golden():\n",
    "    \"\"\"Expected function: the condition and false branch carry `if`; the true branch keeps `true_branch`.\"\"\"\n",
    "    x = relay.var(\"x\", shape=[], dtype=\"float32\", span=_create_span(\"x_var\"))\n",
    "    y = relay.var(\"y\", shape=[], dtype=\"float32\", span=_create_span(\"y_var\"))\n",
    "    cond = _set_span(relay.equal(x, y), \"if\")\n",
    "    then_expr = _set_span(relay.add(x, y), \"true_branch\")\n",
    "    else_expr = _set_span(relay.subtract(x, y), \"if\")\n",
    "    return relay.Function([x, y], relay.If(cond, then_expr, else_expr, span=_create_span(\"if\")))\n",
    "\n",
    "\n",
    "_verify_structural_equal_with_span(_res(), _golden())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 测试 `relay.Function` Span"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _res():\n",
    "    \"\"\"Tag the input var and then the whole `relay.Function` (`func`); the body starts untagged.\"\"\"\n",
    "    data = set_span(relay.var(\"x\", shape=(1, 64, 56, 56)), \"x_var\")\n",
    "    weight = relay.const(np.ones([64, 64, 3, 3]), dtype=\"int64\")\n",
    "    conv = relay.nn.conv2d(data, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))\n",
    "    return set_span(relay.Function([data], conv), \"func\")\n",
    "\n",
    "\n",
    "def _golden():\n",
    "    \"\"\"Expected function: the weight constant, the conv2d call and the function all carry `func`.\"\"\"\n",
    "    data = relay.var(\"x\", shape=(1, 64, 56, 56), span=_create_span(\"x_var\"))\n",
    "    weight = relay.const(np.ones([64, 64, 3, 3]), dtype=\"int64\", span=_create_span(\"func\"))\n",
    "    conv = relay.nn.conv2d(data, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))\n",
    "    return relay.Function([data], _set_span(conv, \"func\"), span=_create_span(\"func\"))\n",
    "\n",
    "\n",
    "_verify_structural_equal_with_span(_res(), _golden())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild the expected function at top level (same construction as `_golden` above)\n",
    "# so that the fully tagged IR can be printed in the next cell.\n",
    "x = relay.var(\"x\", shape=(1, 64, 56, 56), span=_create_span(\"x_var\"))\n",
    "w = relay.const(np.ones([64, 64, 3, 3]), dtype=\"int64\", span=_create_span(\"func\"))\n",
    "conv = relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1))\n",
    "y = _set_span(conv, \"func\")\n",
    "f = relay.Function([x], y, span=_create_span(\"func\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fn (%x: Tensor[(1, 64, 56, 56), float32] /* span=x_var:0:0 */) {\n",
      "  nn.conv2d(%x, meta[relay.Constant][0] /* span=func:0:0 */, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]) /* span=func:0:0 */\n",
      "} /* span=func:0:0 */\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "xin",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
